Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added naps_fibers parameter as output #19

Open
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

JasonMH17
Copy link
Contributor

Changes to carfac.py + commensurate test/benchmark files based on added output that represents naps for each fiber type

Copy link

google-cla bot commented Oct 31, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@JasonMH17
Copy link
Contributor Author

@robsc bit confused as to what has happened with the carfac_bench and whether I have overwritten a previous version of mine or something in the pipeline!!

@JasonMH17
Copy link
Contributor Author

@robsc bit confused as to what has happened with the carfac_bench and whether I have overwritten a previous version of mine or something in the pipeline!!

scrap that...I realize that you had generated a benchmark with "two_cap_with_syn" after I cloned the repo....will make changes accordingly!!

@@ -2352,6 +2354,7 @@ def run_segment(

n_ch = hypers.ears[0].car.n_ch
naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result
naps_fibers = jnp.zeros((n_samp, n_ch, 3, n_ears))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this "3" the 3 from SynDesignParameters.n_classes ? I think it is, right?

If so, is it possible to reference that 3 by name instead of the constant here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

v_recep, ear, weights, state.ears[ear].syn
)
naps_fibers = naps_fibers.at[k, :, :, ear].set(firings)
else:
naps_fibers = naps_fibers.at[k, :, :, ear].set(jnp.zeros([jnp.shape(ihc_out)[0], 3]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about the "3" here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -466,4 +466,4 @@ def bench_jax_util_mapped(state: google_benchmark.State):


if __name__ == '__main__':
google_benchmark.main()
google_benchmark.main()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please restore the newline at end of file here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

copybara-service bot pushed a commit that referenced this pull request Nov 11, 2024
COPYBARA_INTEGRATE_REVIEW=#19 from JasonMH17:master c81cc0f
PiperOrigin-RevId: 695494768
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants